// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

TEST(ContribOpTest, MaxPoolWithMask) {
  // Pools a 1x3x8x8 input with a single 8x8 window (stride 1, no padding), so every
  // channel collapses to one value: the max over the positions whose mask entry is 1.
  OpTester tester("MaxpoolWithMask", 1, onnxruntime::kMSDomain);
  tester.AddAttribute("auto_pad", "");
  tester.AddAttribute("strides", std::vector<int64_t>{1, 1});
  tester.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});
  tester.AddAttribute("kernel_shape", std::vector<int64_t>{8, 8});

  // Shared by X and M: the mask is given with exactly the same shape as the input.
  const std::vector<int64_t> input_shape = {1, 3, 8, 8};

  // Input values, one 8x8 plane per channel, listed row by row.
  const std::vector<float> input_values = {
      0.19151945412158966, 0.6221087574958801, 0.43772774934768677, 0.7853586077690125, 0.7799758315086365, 0.27259260416030884, 0.2764642536640167, 0.801872193813324,
      0.9581393599510193, 0.8759326338768005, 0.35781726241111755, 0.5009950995445251, 0.683462917804718, 0.7127020359039307, 0.37025076150894165, 0.5611962080001831,
      0.5030831694602966, 0.013768449425697327, 0.772826611995697, 0.8826411962509155, 0.36488598585128784, 0.6153962016105652, 0.07538124173879623, 0.3688240051269531,
      0.9331400990486145, 0.6513781547546387, 0.39720258116722107, 0.7887301445007324, 0.3168361186981201, 0.5680986642837524, 0.8691273927688599, 0.4361734092235565,
      0.802147626876831, 0.14376682043075562, 0.7042609453201294, 0.7045813202857971, 0.2187921106815338, 0.9248676300048828, 0.44214075803756714, 0.9093159437179565,
      0.05980922281742096, 0.18428708612918854, 0.047355279326438904, 0.6748809218406677, 0.5946247577667236, 0.5333101749420166, 0.043324064463377, 0.5614330768585205,
      0.32966843247413635, 0.5029668211936951, 0.11189431697130203, 0.6071937084197998, 0.5659446716308594, 0.006764062214642763, 0.617441713809967, 0.912122905254364,
      0.7905241250991821, 0.9920814633369446, 0.9588017463684082, 0.7919641137123108, 0.2852509617805481, 0.6249167323112488, 0.47809380292892456, 0.19567517936229706,

      0.382317453622818, 0.053873684257268906, 0.45164841413497925, 0.9820047616958618, 0.12394270300865173, 0.1193808987736702, 0.7385230660438538, 0.587303638458252,
      0.47163254022598267, 0.10712681710720062, 0.22921857237815857, 0.8999651670455933, 0.41675353050231934, 0.5358516573905945, 0.0062085166573524475, 0.3006417155265808,
      0.43689316511154175, 0.6121490001678467, 0.9181980490684509, 0.625736653804779, 0.7059975862503052, 0.14983370900154114, 0.7460634112358093, 0.8310070037841797,
      0.6337257623672485, 0.4383098781108856, 0.15257278084754944, 0.5684096217155457, 0.5282242894172668, 0.9514287710189819, 0.48035916686058044, 0.5025595426559448,
      0.5368781685829163, 0.8192020654678345, 0.05711563676595688, 0.6694217324256897, 0.7671166062355042, 0.7081153392791748, 0.7968671917915344, 0.5577608346939087,
      0.9658365249633789, 0.14715689420700073, 0.02964700013399124, 0.5938934683799744, 0.11406569927930832, 0.9508098363876343, 0.32570740580558777, 0.19361868500709534,
      0.4578116536140442, 0.9204025864601135, 0.8790691494941711, 0.252615749835968, 0.34800878167152405, 0.18258872628211975, 0.9017960429191589, 0.7065281867980957,
      0.7266584634780884, 0.900087833404541, 0.7791637778282166, 0.5991547703742981, 0.29112523794174194, 0.1513952612876892, 0.33517464995384216, 0.6575517654418945,

      0.07334254682064056, 0.055006396025419235, 0.32319480180740356, 0.5904818177223206, 0.8538985848426819, 0.2870624363422394, 0.17306722700595856, 0.13402120769023895,
      0.9946538209915161, 0.1794978678226471, 0.3175468146800995, 0.568291425704956, 0.009348574094474316, 0.9006485939025879, 0.9772414565086365, 0.5568946599960327,
      0.08477384597063065, 0.3330024778842926, 0.7284286618232727, 0.14243537187576294, 0.5524689555168152, 0.2730432450771332, 0.9744951128959656, 0.6677868962287903,
      0.2556532919406891, 0.1083114966750145, 0.7761807441711426, 0.7824779748916626, 0.7616038918495178, 0.9144031405448914, 0.6586228013038635, 0.568367600440979,
      0.20175568759441376, 0.6982963681221008, 0.952195405960083, 0.8899632692337036, 0.9935673475265503, 0.8187035322189331, 0.5451221466064453, 0.45125406980514526,
      0.8905571699142456, 0.9732648134231567, 0.5934113264083862, 0.36607450246810913, 0.3230946958065033, 0.8714232444763184, 0.2156340628862381, 0.7349451780319214,
      0.36561909317970276, 0.8016026020050049, 0.7827355861663818, 0.7013553977012634, 0.6227765679359436, 0.4936826527118683, 0.8405377268791199, 0.7120969891548157,
      0.4439089894294739, 0.031034860759973526, 0.36323976516723633, 0.7307217717170715, 0.475566565990448, 0.3444169759750366, 0.6408804059028625, 0.12620532512664795};

  // Mask: in every row of every channel only the first three columns are active.
  // 3 channels x 8 rows = 24 copies of the same 8-wide row pattern.
  const std::vector<int32_t> row_pattern = {1, 1, 1, 0, 0, 0, 0, 0};
  std::vector<int32_t> mask_values;
  for (int row = 0; row < 3 * 8; ++row) {
    mask_values.insert(mask_values.end(), row_pattern.begin(), row_pattern.end());
  }

  // Expected per-channel max over columns 0-2. Channels 0 and 2 have their global
  // maximum inside the active columns. Channel 1's global maximum (0.982..., row 0,
  // column 3) is masked out, so channel 1 is the one that tells masked pooling apart
  // from plain max pooling.
  const std::vector<int64_t> output_shape = {1, 3, 1, 1};
  const std::vector<float> output_values = {0.9920814633369446, 0.9658365249633789, 0.9946538209915161};

  tester.AddInput<float>("X", input_shape, input_values);
  tester.AddInput<int32_t>("M", input_shape, mask_values);
  tester.AddOutput<float>("Y", output_shape, output_values);
  tester.Run();
}

}  // namespace test
}  // namespace onnxruntime
